#!/usr/bin/env python3
from __future__ import annotations
import argparse, pathlib, sys, os, re, importlib.util, inspect
from typing import Any, Dict, List, Tuple
import numpy as np

# ---------- repo paths ----------
# Directory that contains this script, and its parent directory, which this
# file treats as the repository root when resolving relative paths.
MODULE_DIR = pathlib.Path(__file__).resolve().parent
REPO_ROOT  = MODULE_DIR.parent
# Put the repo root at the front of sys.path (once) so that modules located
# there can be imported by name -- presumably `sim_utils`, imported just below.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

# ---------- utilities (prefer repo's sim_utils) ----------
try:
    from sim_utils import load_config, sweep_iter, seed_all, save_csv  # type: ignore
except Exception:
    # Deliberately broad: any failure to import the repo helpers (module
    # missing, names missing, or an error raised while importing it) selects
    # the self-contained fallbacks defined below.
    import csv
    import random

    def load_config(p: str) -> Dict[str, Any]:
        """Parse the YAML file at ``p`` and return its top-level mapping.

        An empty document (for which ``yaml.safe_load`` returns ``None``)
        yields ``{}`` so callers can use ``cfg.get(...)`` unconditionally.
        """
        # Imported lazily so the other fallbacks stay usable without PyYAML.
        import yaml

        with open(p, "r", encoding="utf-8") as f:
            return yaml.safe_load(f) or {}

    def sweep_iter(cfg: Dict[str, Any]):
        """Yield ``(b, k, n0, L)`` for every combination in the config.

        Reads ``b_values``, ``k_values``, ``n0_values`` and ``L_values`` from
        ``cfg``; ``b`` varies slowest and ``L`` fastest. The first three are
        converted to ``float`` and ``L`` to ``int``.
        """
        for b in cfg["b_values"]:
            for k in cfg["k_values"]:
                for n0 in cfg["n0_values"]:
                    for L in cfg["L_values"]:
                        yield float(b), float(k), float(n0), int(L)

    def seed_all(b: float, k: float, n0: float, L: int) -> None:
        """Seed ``random`` and ``numpy.random`` from the sweep point."""
        # XOR of the parameters scaled to integers, masked to 32 bits because
        # np.random.seed only accepts values in [0, 2**32 - 1].
        seed = (int(round(b*1e6)) ^ int(round(k*1e6))
                ^ int(round(n0*1e6)) ^ int(L)) & 0xFFFFFFFF
        random.seed(seed)
        np.random.seed(seed)

    def save_csv(path: str | pathlib.Path, row: Dict[str, Any]) -> None:
        """Append ``row`` to the CSV at ``path``, creating it if necessary.

        The parent directory is created on demand and the header is written
        when the file is new or still empty. Columns absent from ``row`` are
        written as empty fields; ``row`` itself is not modified. Keys that are
        not known columns raise ``ValueError`` (``csv.DictWriter`` default).
        """
        path = str(path)
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        hdr = ["L","b","gauge","k","n0","sigma","sigma_err","string_tension","string_tension_err"]
        # An existing zero-byte file still needs its header line.
        need_hdr = not os.path.exists(path) or os.path.getsize(path) == 0
        with open(path, "a", newline="", encoding="utf-8") as f:
            w = csv.DictWriter(f, fieldnames=hdr)
            if need_hdr:
                w.writeheader()
            # DictWriter fills missing columns with its default restval (""),
            # which is what a None value is written as too, so there is no
            # need to insert placeholder keys into the caller's dict.
            w.writerow(row)

# ---------- path helpers ----------
def kernel_path_from_template(tpl: str, gauge: str, L: int) -> str:
    """Expand a kernel path template and return it as a path string.

    ``tpl`` is formatted with the keyword fields ``gauge`` and ``L``. An
    absolute result is returned exactly as formatted; a relative result is
    joined onto ``REPO_ROOT`` and resolved.
    """
    expanded = tpl.format(gauge=gauge, L=L)
    if os.path.isabs(expanded):
        return expanded
    return str((REPO_ROOT / expanded).resolve())

def flip_counts_path_from_template(tpl: str, L: int) -> str:
    """Expand a flip-counts path template and return it as a path string.

    ``tpl`` is formatted with the keyword field ``L``. Relative results are
    anchored at ``REPO_ROOT`` and resolved; absolute results pass through
    unchanged.
    """
    candidate = tpl.format(L=L)
    needs_anchor = not os.path.isabs(candidate)
    if needs_anchor:
        candidate = str((REPO_ROOT / candidate).resolve())
    return candidate

def _import_module_from_path(p: pathlib.Path, modname="adjoint_tension_mod"):
    """Load the Python source file at ``p`` as a module named ``modname``.

    Returns the executed module, or ``None`` when no import spec/loader can
    be built for the path. Any exception raised while executing the module
    propagates to the caller; in that case the ``sys.modules`` entry for
    ``modname`` is restored to what it was before the call, so a
    half-initialised module is never left registered.
    """
    spec = importlib.util.spec_from_file_location(modname, str(p))
    if spec is None or spec.loader is None:
        return None

    mod = importlib.util.module_from_spec(spec)
    previous = sys.modules.get(spec.name)
    # IMPORTANT: register before exec so dataclasses works (it looks the
    # defining module up in sys.modules while the class body is processed).
    sys.modules[spec.name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Undo the registration so later imports do not see a broken module.
        if previous is None:
            sys.modules.pop(spec.name, None)
        else:
            sys.modules[spec.name] = previous
        raise
    return mod

# ---------- locate compute callable ----------
def find_compute_callable(cfg: Dict[str, Any]):
    """Locate the string-tension compute function and return ``(fn, signature)``.

    Candidate files are tried in this order:

    1. ``adjoint_volume.compute_module`` from ``cfg`` (relative paths are
       taken from ``REPO_ROOT``), if set.
    2. Every ``*.py`` file under ``MODULE_DIR / "orig"`` and
       ``REPO_ROOT / "orig"`` whose text defines one of the wanted function
       names, in sorted path order so the choice is reproducible.

    Within a module the first callable attribute found wins, checked in the
    order ``compute_function_alt`` -> ``compute_function`` ->
    ``compute_tension_with_err`` -> ``compute_string_tension`` ->
    ``compute_tension``.

    An import error in the explicitly configured module propagates; a scanned
    file that fails to import is reported on stderr and skipped.

    Raises:
        RuntimeError: if no candidate module provides a matching callable.
    """
    # `or {}` also covers an `adjoint_volume:` key that is present but null.
    av = dict(cfg.get("adjoint_volume") or {})
    explicit_path = av.get("compute_module")
    explicit_func = av.get("compute_function")
    explicit_alt  = av.get("compute_function_alt", "compute_tension_with_err")

    # Preference order: explicit_alt -> explicit_func -> with_err -> string_tension -> tension
    # (built once, duplicates dropped; it does not depend on the candidate).
    names: List[str] = []
    for name in (explicit_alt, explicit_func,
                 "compute_tension_with_err", "compute_string_tension", "compute_tension"):
        if isinstance(name, str) and name and name not in names:
            names.append(name)

    # candidate files (explicit first), each tagged with whether it was explicit
    candidates: List[Tuple[pathlib.Path, bool]] = []
    if explicit_path:
        p = (REPO_ROOT / explicit_path) if not os.path.isabs(explicit_path) else pathlib.Path(explicit_path)
        if not p.exists():
            print(f"[find_compute_callable] compute_module not found: {p}", file=sys.stderr)
        candidates.append((p, True))

    # scan both module/orig and repo-root/orig for files defining a wanted name
    def_pattern = re.compile(r"\bdef\s+(" + "|".join(re.escape(n) for n in names) + r")\s*\(")
    for root in (MODULE_DIR / "orig", REPO_ROOT / "orig"):
        if not root.exists():
            continue
        # sorted(): rglob order depends on the filesystem.
        for p in sorted(root.rglob("*.py")):
            try:
                txt = p.read_text(encoding="utf-8", errors="ignore")
            except OSError:
                continue
            if def_pattern.search(txt):
                candidates.append((p, False))

    seen = set()
    for p, is_explicit in candidates:
        if not p.exists():
            continue
        # Deduplicate on the resolved path so one file reached through two
        # spellings (or through both roots) is imported only once.
        key = p.resolve()
        if key in seen:
            continue
        seen.add(key)
        try:
            mod = _import_module_from_path(p, modname=f"adj_tension_{abs(hash(p))}")
        except Exception as exc:
            if is_explicit:
                # The user asked for this module; do not silently use another.
                raise
            print(f"[find_compute_callable] skipping {p}: import failed ({exc!r})", file=sys.stderr)
            continue
        if not mod:
            continue

        for name in names:
            fn = getattr(mod, name, None)
            if callable(fn):
                return fn, inspect.signature(fn)
    raise RuntimeError(
        "Could not locate compute_* function. Set adjoint_volume.compute_module and "
        "adjoint_volume.compute_function(_alt) in your YAML."
    )

# ---------- kwarg selection ----------
def select_kwargs(sig: inspect.Signature, context: Dict[str, Any]) -> Dict[str, Any]:
    """Build the keyword arguments to pass to a compute function.

    Each entry of the alias table pairs a parameter name the target may
    declare with the ``context`` key that supplies its value. A pair is used
    only when the parameter appears in ``sig`` and the context value exists
    and is not ``None``. The result keeps the table's order.
    """
    accepted = sig.parameters
    # (parameter name on the target, key in ``context``)
    aliases = (
        # physics args
        ("L", "L"),
        ("k_exp", "k"),
        ("k", "k"),
        ("b", "b"),
        ("n0", "n0"),
        ("gauge", "gauge"),
        ("representation", "representation"),
        ("rep", "representation"),
        ("kernel_path", "kernel_path"),
        ("flip_counts_path", "flip_counts_path"),
        ("fit_range", "fit_range"),
        ("volumes", "volumes"),
        ("loop_sizes", "loop_sizes"),
        ("loops", "loop_sizes"),
        ("seed", "seed"),
        ("rng", "rng"),
        # bootstrap / guards (only passed if the function accepts them)
        ("bootstrap_reps", "bootstrap_reps"),
        ("bootstrap_block", "bootstrap_block"),
        ("bootstrap_seed", "bootstrap_seed"),
        ("kernel_hash_guard", "kernel_hash_guard"),
    )
    return {
        param: context[key]
        for param, key in aliases
        if param in accepted and context.get(key) is not None
    }

# ---------- main ----------
def main() -> None:
    """Run the adjoint volume sweep and append the results to a summary CSV.

    Command line: ``--config/-c`` (YAML config) and ``--output-dir/-o``; the
    summary is written to ``<output-dir>/adjoint_volume_summary.csv``.

    For every ``(b, k, n0, L)`` from ``sweep_iter`` and every gauge group, the
    kernel and flip-count paths are expanded from their templates, the compute
    function found by ``find_compute_callable`` is called with the subset of
    the context it accepts, and one row is saved holding the raw ``sigma`` and
    its error plus both divided by ``2 * L**2``.

    Raises:
        ValueError: no flip-counts template, or no kernel template for a gauge.
        FileNotFoundError: an expanded kernel or flip-counts path is missing.
    """
    ap = argparse.ArgumentParser(description="Adjoint volume sweep (normalized per-plaquette, with optional bootstrap)")
    ap.add_argument("--config","-c", required=True)
    ap.add_argument("--output-dir","-o", required=True)
    args = ap.parse_args()

    cfg = load_config(args.config)
    out_dir = pathlib.Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_csv = out_dir / "adjoint_volume_summary.csv"

    # The `or ...` fallbacks below also cover keys that exist in the YAML but
    # are null, which `.get(key, default)` alone would pass through as None.
    av = dict(cfg.get("adjoint_volume") or {})
    gauges = av.get("gauge_groups")
    if gauges is None:
        gauges = ["SU2","SU3"]
    fit_range = av.get("fit_range")
    kernel_tpl_by_g = av.get("kernel_paths") or {}
    # Global fallback: either a per-gauge mapping or one template string
    # (which may itself use the {gauge} placeholder).
    global_kernel_tpl = cfg.get("kernel_path_template")
    flip_tpl = av.get("flip_counts_path_template") or cfg.get("flip_counts_path_template", "")
    volumes_cfg = av.get("volumes")
    loop_sizes_cfg = list(cfg.get("loop_sizes") or [1,2,3,4,5,6])
    if not flip_tpl:
        raise ValueError("flip_counts_path_template missing (adjoint_volume or global)")

    # Bootstrap / guard controls (read from YAML; only passed if the function supports them)
    bootstrap_reps  = int(av.get("bootstrap_reps", 400))
    bootstrap_block = int(av.get("bootstrap_block", 0))    # 0 = auto
    bootstrap_seed  = int(av.get("bootstrap_seed", 1337))
    kernel_hash_guard = bool(av.get("kernel_hash_guard", True))

    compute_fn, compute_sig = find_compute_callable(cfg)

    for b, k, n0, L in sweep_iter(cfg):
        seed_all(b, k, n0, L)
        for g in gauges:
            # Per-gauge template from adjoint_volume first, then the global one.
            ktpl = kernel_tpl_by_g.get(g)
            if not ktpl:
                if isinstance(global_kernel_tpl, dict):
                    ktpl = global_kernel_tpl.get(g)
                else:
                    ktpl = global_kernel_tpl
            if not ktpl:
                raise ValueError(f"No kernel template found for gauge {g}")
            kernel_path = kernel_path_from_template(ktpl, g, L)
            flip_path   = flip_counts_path_from_template(flip_tpl, L)
            if not os.path.exists(kernel_path):
                raise FileNotFoundError(f"Kernel not found: {kernel_path}")
            if not os.path.exists(flip_path):
                raise FileNotFoundError(f"Flip counts not found: {flip_path}")

            context: Dict[str, Any] = {
                "L": int(L),
                "b": float(b),
                "k": float(k),
                "n0": float(n0),
                "gauge": str(g),
                "representation": "adjoint",
                "kernel_path": kernel_path,
                "flip_counts_path": flip_path,
                "fit_range": list(fit_range) if fit_range else None,
                "loop_sizes": loop_sizes_cfg,
                "volumes": list(volumes_cfg) if volumes_cfg else [int(L)],
                # None entries are dropped by select_kwargs, so no seed is passed.
                "seed": None,
                # bootstrap / guard (will be filtered out if fn doesn't accept them)
                "bootstrap_reps": bootstrap_reps,
                "bootstrap_block": bootstrap_block,
                "bootstrap_seed": bootstrap_seed,
                "kernel_hash_guard": kernel_hash_guard,
            }

            # Prepare kwargs the target actually accepts
            kwargs = select_kwargs(compute_sig, context)

            # Call the compute function; handle scalar or (sigma, sigma_err, meta)
            result = compute_fn(**kwargs)
            if isinstance(result, (tuple, list)) and len(result) >= 2:
                sigma_raw = float(result[0])
                sigma_raw_err = float(result[1])
            else:
                sigma_raw = float(result)
                sigma_raw_err = 0.0

            # Normalize to per-plaquette units
            area_scale = 2.0 * (L**2)
            string_tension     = sigma_raw     / area_scale
            string_tension_err = sigma_raw_err / area_scale

            save_csv(out_csv, {
                "L": L, "b": b, "gauge": g, "k": k, "n0": n0,
                "sigma": float(sigma_raw),
                "sigma_err": float(sigma_raw_err),
                "string_tension": float(string_tension),
                "string_tension_err": float(string_tension_err),
            })

    print(f"✅ adjoint volume summary → {out_csv}")

# Script entry point: run the sweep only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
